{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4545c6ad",
   "metadata": {},
   "source": [
    "# PC Algorithm for Distribution Shift Detection in Tabular Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4abfea4",
   "metadata": {},
   "source": [
    "PC algorithm can detect the origins of distribution shifts in tabular, continous/discrete data with the help of domain index variable. The algorithm uses the PC algorithm to estimate the causal graph, by treating distribution shifts as intervention of the domain index on the root cause node, and PC can use  conditional independence tests to quickly recover the causal graph and detec the root cause of anomaly. Note that the algorithm supports both discrete and continuous variables, and can handle nonlinear relationships by converting the continous variables into discrete ones using K-means clustering and using discrete PC algorithm instead for CI test and causal discovery."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "627c40e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from causalai.data.data_generator import DataGenerator\n",
    "from causalai.application import TabularDistributionShiftDetector\n",
    "from causalai.application.common import distshift_detector_preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09d9b8bb",
   "metadata": {},
   "source": [
    "### Generate tabular data with two domains with distribution shifts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2924dcb",
   "metadata": {},
   "source": [
    "We add distribution shifts on the **node b**. Because of the causal influnences, the anomaly on node b proporgates along the causal graph to node d as well. However, node d is not the cause of distribution shifts. Our algorithm is supposed to only return node b as the root cause of anomaly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cc751b74",
   "metadata": {},
   "outputs": [],
   "source": [
    "fn_normal = lambda x:x\n",
    "fn_abnormal = lambda x:x+1\n",
    "\n",
    "coef = 1.0\n",
    "sem_normal = {\n",
    "        'a': [], \n",
    "        'b': [('a', coef, fn_normal)], \n",
    "        'c': [('a', coef, fn_normal)],\n",
    "        'd': [('b', coef, fn_normal), ('c', coef, fn_normal)]\n",
    "        }\n",
    "\n",
    "sem_abnormal = {\n",
    "        'a': [], \n",
    "        'b': [('a', coef, fn_abnormal)], \n",
    "        'c': [('a', coef, fn_normal)],\n",
    "        'd': [('b', coef, fn_normal), ('c', coef, fn_normal)]\n",
    "        }\n",
    "\n",
    "T = 1000\n",
    "data_array_normal, var_names, graph_gt = DataGenerator(sem_normal, T=T, seed=0, discrete=False)\n",
    "data_array_abnormal, var_names, graph_gt = DataGenerator(sem_abnormal, T=T, seed=1, discrete=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "17fa640c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_normal = pd.DataFrame(data=data_array_normal, columns=var_names)\n",
    "df_abnormal = pd.DataFrame(data=data_array_abnormal, columns=var_names)\n",
    "c_idx = np.array([0]*T + [1]*T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d82b30c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_obj, var_names = distshift_detector_preprocess(\n",
    "    data=[df_normal, df_abnormal],\n",
    "    domain_index=c_idx,\n",
    "    domain_index_name='domain_index',\n",
    "    n_states=2\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe623840",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['a', 'b', 'c', 'd', 'domain_index']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var_names"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b02278f",
   "metadata": {},
   "source": [
    "### Run the tabular distribution shift detector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5652e2f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TabularDistributionShiftDetector(\n",
    "    data_obj=data_obj,\n",
    "    var_names=var_names,\n",
    "    domain_index_name='domain_index',\n",
    "    prior_knowledge=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4887d486",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The distribution shifts are from the nodes: {'b'}\n"
     ]
    }
   ],
   "source": [
    "root_causes, graph = model.run(\n",
    "    pvalue_thres=0.01,\n",
    "    max_condition_set_size=4,\n",
    "    return_graph=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "90f35c5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'b'}\n"
     ]
    }
   ],
   "source": [
    "print(root_causes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "57a73f4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': {'c', 'b', 'd'}, 'b': {'a', 'd'}, 'c': {'a', 'd'}, 'd': {'c', 'b', 'a'}, 'domain_index': {'b'}}\n"
     ]
    }
   ],
   "source": [
    "print(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59ee7bc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CausalAI",
   "language": "python",
   "name": "causal_ai_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
